import os
import argparse
import torch
import numpy as np

from PIL import Image
import time

import torchvision
from torchvision import transforms as transforms

# Command-line interface. Parsing happens at module level, so the resulting
# ``args`` namespace is available to the rest of the script as a global.
parser = argparse.ArgumentParser(description='pretrained mobilenet')

# Text file that lists the images to process; an empty string means "not given".
parser.add_argument(
    '--input_txt',
    type=str,
    default='',
    help='dataset path',
)

# Root directory under which the per-image result files are written.
parser.add_argument(
    '--save_folder',
    type=str,
    default='./testout/',
    help='Dir to save txt results',
)

# Boolean switch: present -> True, absent -> False.
parser.add_argument(
    '--cpu',
    action="store_true",
    default=False,
    help='Use cpu inference',
)

args = parser.parse_args()


def _read_image_list(list_path):
    """Return the image paths listed in the text file at *list_path*.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are dropped so they never reach the image loader.
    The file is closed before returning.
    """
    with open(list_path, 'r') as f_list:
        stripped = (line.strip() for line in f_list)
        return [line for line in stripped if line]


def _result_path(save_folder, image_path):
    """Return the ``.txt`` path where the result for *image_path* is stored.

    The layout is ``<save_folder>/<parent dir name>/<stem>.txt``. When
    *image_path* has no parent directory the sub-folder component is empty
    and the file lands directly in *save_folder*.
    """
    img_name = os.path.basename(image_path)
    subfolder = os.path.basename(os.path.dirname(image_path))
    # NOTE: the stem is everything before the FIRST dot, so "a.b.jpg" and
    # "a.c.jpg" in the same folder map to the same "a.txt". Kept as-is so that
    # result files written by earlier runs are still recognised.
    stem = img_name.split('.')[0]
    return os.path.join(save_folder, subfolder, stem + ".txt")


def main():
    """Run the pretrained MobileNetV2 over every listed image.

    For each image path in ``args.input_txt`` the softmax of the model output
    is written as one comma-separated line to the file given by
    ``_result_path``. Images whose result file already exists are skipped,
    which lets an interrupted run be resumed.

    Raises:
        SystemExit: if ``--input_txt`` was not supplied.
    """
    # Fail fast with a readable message instead of a NameError later on.
    if not args.input_txt:
        raise SystemExit('--input_txt is required: path to a text file with one image path per line')

    torch.set_grad_enabled(False)
    # net and model
    model = torchvision.models.mobilenet_v2(pretrained=True)
    model.eval()
    print('Finished loading model!')
    print(model)

    # Honour --cpu; otherwise use CUDA only when it is actually available.
    use_cuda = (not args.cpu) and torch.cuda.is_available()
    if not args.cpu and not use_cuda:
        print('CUDA is not available, falling back to cpu', flush=True)
    device = torch.device("cuda" if use_cuda else "cpu")
    model = model.to(device)

    test_dataset = _read_image_list(args.input_txt)
    num_images = len(test_dataset)
    print(f'done preparing dataset!, N={num_images}', flush=True)

    # Per-channel mean/std are the values torchvision documents for its
    # ImageNet-pretrained classification models.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # testing begin
    for i, image_path in enumerate(test_dataset):
        start = time.perf_counter()
        save_name = _result_path(args.save_folder, image_path)

        # Check for an existing result BEFORE decoding the image, so resumed
        # runs do not pay the image-loading cost for finished items.
        if not os.path.exists(save_name):
            dirname = os.path.dirname(save_name)
            if dirname:
                os.makedirs(dirname, exist_ok=True)

            # Context manager releases the underlying file handle promptly.
            with Image.open(image_path) as raw:
                img = preprocess(raw.convert('RGB')).to(device)
            probs = model(img.unsqueeze(0))[0].softmax(-1).cpu().numpy()

            # Write to a temporary name and rename, so an interrupted run never
            # leaves a truncated file that would later be mistaken for a
            # finished result.
            tmp_name = save_name + ".tmp"
            with open(tmp_name, "w") as fd:
                fd.write(",".join(map(str, probs)))
            os.replace(tmp_name, save_name)

        if i % 10 == 0:
            # Measure once so both printed figures agree; the floor avoids a
            # division by zero on a coarse clock.
            elapsed = max(time.perf_counter() - start, 1e-9)
            print(f"im_detect: {i + 1:5}/{num_images} Time: {elapsed:.3f}s",
                  f"== {1. / elapsed:.1f}Hz", flush=True)


if __name__ == '__main__':
    main()